package com.gcloud.mesh.ml;

import java.security.SecureRandom;
import java.text.DecimalFormat;

/**
 * Multivariate linear regression fitted with batch gradient descent.
 *
 * <p>The model is {@code h(x) = theta[0]*x[0] + theta[1]*x[1] + ...}. Callers
 * supply the bias input themselves; {@link #mock} does so by always passing
 * {@code x0 = 1}.
 *
 * <p>Typical use: {@link #setTrainingData}, then {@link #training()}, then
 * {@link #predict}.
 */
public class LinearRegression implements IMachineLearning {

    /** Random source used only by {@code main} to generate demo inputs. */
    private static final SecureRandom rd = new SecureRandom();

    /** Formatter used by {@link #D} and {@link #DA}; prints at most 12 fractional digits. */
    static DecimalFormat fmt = new DecimalFormat("#.############");

    /** Data set used for training; {@code null} until {@link #setTrainingData} is called. */
    private TrainingSet ts = null;

    /** Learned weights, one per input feature; {@code null} until first initialized. */
    private double[] theta;

    /** Learning rate: the step size applied to the value returned by {@link #calculateDelta}. */
    private final double alpha = 0.001;

    /** Not read by any method of this class at present. */
    private double stopThreshold = 0.0;

    /** Set to {@code true} when {@link #study()} finishes; reset to {@code false} when it starts. */
    boolean trained = false;

    public LinearRegression() { }

    /**
     * Adds one synthetic sample to {@code ts} (testing helper).
     *
     * <p>The target is generated as {@code y = theta0*x0 + theta1*x1 + theta2*x2}
     * with {@code x0} fixed to 1 so that {@code theta0} acts as the bias term.
     *
     * @param ts data set that receives the sample
     * @param x1 first feature value
     * @param x2 second feature value
     */
    static void mock(TrainingSet ts, double x1, double x2) {
        // Ground-truth weights used to generate the data; a successful training
        // run is expected to learn values close to these.
        final double theta0 = 3.2, theta1 = 1.5, theta2 = 0.3;
        final double x0 = 1;
        double y = theta0*x0 + theta1*x1 + theta2*x2;
        ts.add(y, x0, x1, x2);
    }

    /**
     * Sets the data used by {@link #training()}.
     *
     * <p>The weight vector is allocated (all zeros) on first use. On later calls
     * the current weights are kept as a warm start when the feature count is
     * unchanged; when the feature count differs, the weights are reallocated and
     * the model is marked as not trained.
     *
     * @param data training data; must contain at least one sample
     * @throws NullPointerException if {@code data} is {@code null}
     * @throws IllegalArgumentException if {@code data} contains no samples
     */
    @Override
    public void setTrainingData(TrainingSet data) {
        if (data == null) {
            throw new NullPointerException("training data must not be null");
        }
        if (data.dataList == null || data.dataList.size() == 0) {
            throw new IllegalArgumentException("training data must contain at least one sample");
        }
        int featureCount = data.dataList.get(0).x.length;
        if (this.theta != null && this.theta.length != featureCount) {
            // Old weights cannot be applied to samples of a different width.
            this.theta = null;
            this.trained = false;
        }
        this.ts = data;
        this.initTheta(this.ts);
    }

    /**
     * Runs gradient descent on the configured training data.
     *
     * @throws IllegalStateException if no training data has been set
     */
    @Override
    public void training() throws Exception {
        if (this.ts == null) {
            throw new IllegalStateException("training data is not setup yet.");
        }
        this.study();
    }

    /**
     * Evaluates the model on {@code d.x}.
     *
     * <p>The prediction is written into {@code d.y}; the same instance that was
     * passed in is returned (no copy is made).
     *
     * @param d sample whose {@code x} holds the input features
     * @return {@code d}, with {@code y} overwritten by the prediction
     * @throws IllegalStateException if the weights have not been initialized yet
     */
    @Override
    public TrainingSet.Data predict(TrainingSet.Data d) {
        if (this.theta == null) {
            throw new IllegalStateException("model is not initialized; call setTrainingData first.");
        }
        d.y = this.hypothesis(d.x);
        return d;
    }

    /** Demo entry point: trains on 100 synthetic samples produced by {@link #mock}. */
    public static void main(String[] args) {
        TrainingSet ts = new TrainingSet();
        // Generate 100 synthetic samples.
        for (int i = 0; i < 100; i++) {
            mock(ts, rd.nextDouble(), rd.nextDouble());
        }

        LinearRegression lr = new LinearRegression();
        lr.setTrainingData(ts);
        lr.study();
    }

    /** @return whether the most recent {@link #study()} run has completed */
    public boolean trained() {
        return this.trained;
    }

    /**
     * Runs batch gradient descent until the cost stops strictly decreasing.
     *
     * <p>Each cycle evaluates the cost, and if it improved on the best cost seen
     * so far, moves every weight by {@code alpha * delta}. When a step fails to
     * lower the cost, the weights are rolled back to the best ones observed so
     * that the last, non-improving step is not kept.
     *
     * @throws IllegalStateException if training data or weights are not set up
     */
    void study() {
        if (this.ts == null || this.theta == null) {
            throw new IllegalStateException("training data is not setup yet.");
        }

        long cycleTimes = 0;
        double costMin = Double.MAX_VALUE;
        // Weights that produced costMin; restored when a later step does not improve.
        double[] bestTheta = theta.clone();
        this.trained = false;

        while (true) {
            double cost = calculateCost(this.ts);
            if (!(cost < costMin)) {
                // The step taken in the previous cycle did not help (or the cost
                // is NaN): undo it before stopping.
                System.arraycopy(bestTheta, 0, theta, 0, theta.length);
                System.out.println("最终得到不变的损失值，退出学习, cost：" + D(cost) + " costMin:" + D(costMin));
                break;
            }

            // Current weights are the best so far; remember them.
            costMin = cost;
            System.arraycopy(theta, 0, bestTheta, 0, theta.length);

            // Compute every component before touching theta so that all of them
            // are evaluated against the same weights.
            double[] delta = new double[theta.length];
            cycleTimes++;
            System.out.println("cycle times:" + cycleTimes);
            for (int i = 0; i < theta.length; i++) {
                delta[i] = calculateDelta(this.ts, i);
                System.out.println("delta"+i+" = " + D(delta[i]) + ", cost=" + D(cost) + ", θ"+i+"=" + D(theta[i]));
            }
            // Take one small step against the gradient; alpha is the learning rate.
            for (int i = 0; i < theta.length; i++) {
                theta[i] = theta[i] - alpha * delta[i];
            }
        }
        this.trained = true;
        System.out.println("本轮学习得到的θ为：" + DA(theta));
    }

    /**
     * Allocates the weight vector if it does not exist yet.
     *
     * <p>The length is taken from the feature array of the first sample in
     * {@code ts}. Existing weights are left untouched.
     */
    void initTheta(TrainingSet ts) {
        if (null == theta) {
            theta = new double[ts.dataList.get(0).x.length];
        }
    }

    /**
     * Prediction function.
     *
     * @param x input features
     * @return {@code theta[0]*x[0] + theta[1]*x[1] + ...} over all entries of {@code x}
     */
    double hypothesis(double[] x) {
        double value = 0;
        for (int i = 0; i < x.length; i++) {
            value += theta[i] * x[i];
        }
        return value;
    }

    /**
     * Loss function.
     *
     * <p>{@code J(theta) = sum over samples of (h(x) - y)^2 / (2m)}, where
     * {@code m} is the number of samples in {@code ts}.
     */
    double calculateCost(TrainingSet ts) {
        double variance = 0;
        for (TrainingSet.Data data : ts.dataList) {
            variance += Math.pow(hypothesis(data.x) - data.y, 2);
        }
        return variance / (2 * ts.dataList.size());
    }

    /**
     * Computes the descent direction for weight {@code i}.
     *
     * <p>Returns {@code 2 * sum over samples of x[i] * (h(x) - y)}, i.e. the
     * partial derivative of the un-normalized squared error
     * {@code sum (h(x) - y)^2}. This is {@code 2m} times the partial derivative
     * of {@link #calculateCost} ({@code m} = sample count); the constant factor
     * is effectively folded into the step size {@code alpha} used by
     * {@link #study()}.
     *
     * @param i index of the weight to differentiate with respect to
     */
    double calculateDelta(TrainingSet ts, int i) {
        double sum = 0;
        for (TrainingSet.Data data : ts.dataList) {
            sum += 2 * data.x[i] * ( hypothesis(data.x) - data.y );
        }
        return sum;
    }

    /** Formats a single value with {@link #fmt}. */
    static String D(double val) {
        return fmt.format(val);
    }

    /**
     * Formats a weight vector as {@code "θ0 = v0 , θ1 = v1 , "}; every entry,
     * including the last one, is followed by {@code " , "}.
     */
    static String DA(double[] vals) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < vals.length; i++) {
            sb.append("θ").append(i).append(" = ").append(fmt.format(vals[i])).append(" , ");
        }
        return sb.toString();
    }
}